{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gradio as gr\n",
    "import numpy as np\n",
    "from PIL import Image, ImageDraw, ImageFont\n",
    "import matplotlib.pyplot as plt\n",
    "import cv2\n",
    "from segment_anything import sam_model_registry\n",
    "from segment_anything.predictor_sammed import SammedPredictor\n",
    "from argparse import Namespace\n",
    "import torch\n",
    "import torchvision\n",
    "import os, sys\n",
    "import random\n",
    "import warnings\n",
    "from scipy import ndimage\n",
    "import functools\n",
    "\n",
    "\n",
    "# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "\n",
    "# Settings passed to the model builder in load_model(); created in a single call.\n",
    "args = Namespace(\n",
    "    device=device,\n",
    "    image_size=256,\n",
    "    encoder_adapter=True,\n",
    "    sam_checkpoint=\"pretrain_model/sam-med2d_b.pth\",  # checkpoint names from the original note: sam_vit_b.pth, sam-med2d_b.pth\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_model(args):\n",
    "    \"\"\"Build the ``vit_b`` model described by ``args`` and wrap it in a predictor.\n",
    "\n",
    "    The model is created through ``sam_model_registry``, moved to ``args.device``\n",
    "    and switched to eval mode before being handed to ``SammedPredictor``.\n",
    "\n",
    "    Returns:\n",
    "        The ``SammedPredictor`` wrapping the freshly built model.\n",
    "    \"\"\"\n",
    "    sam = sam_model_registry[\"vit_b\"](args)\n",
    "    sam = sam.to(args.device)\n",
    "    sam.eval()\n",
    "    return SammedPredictor(sam)\n",
    "\n",
    "\n",
    "# `args` is shared and mutated in place, so set the flag explicitly before each load.\n",
    "# Without the first assignment, re-running this cell would build the with-adapter\n",
    "# predictor from encoder_adapter=False left over from the previous run.\n",
    "args.encoder_adapter = True\n",
    "predictor_with_adapter = load_model(args)\n",
    "args.encoder_adapter = False\n",
    "predictor_without_adapter = load_model(args)\n",
    "\n",
    "def run_sammed(input_image, selected_points, last_mask, adapter_type):\n",
    "    \"\"\"Segment ``input_image`` from the clicked points and render the result.\n",
    "\n",
    "    Args:\n",
    "        input_image: image array as stored by the UI; unpacked here as (H, W, C).\n",
    "        selected_points: list of ``((x, y), label)`` pairs; label 1 is drawn\n",
    "            green and label 0 red by ``draw_point``.\n",
    "        last_mask: second return value of the previous call (or ``None``);\n",
    "            forwarded to the predictor as ``mask_input``.\n",
    "        adapter_type: ``\"SAM-Med2D-B\"`` selects ``predictor_with_adapter``;\n",
    "            any other value selects ``predictor_without_adapter``.\n",
    "\n",
    "    Returns:\n",
    "        ``[(annotated_image, mask_overlay), last_mask]``: two RGBA PIL images\n",
    "        and the sigmoid of the predictor's third output as a float tensor on\n",
    "        ``device``.\n",
    "\n",
    "    Raises:\n",
    "        gr.Error: if no image is stored or no point has been selected.\n",
    "    \"\"\"\n",
    "    # Fail with a readable UI message instead of an opaque traceback.\n",
    "    if input_image is None:\n",
    "        raise gr.Error('Please upload an image or pick an example first.')\n",
    "    if not selected_points:\n",
    "        raise gr.Error('Please click on the image to select at least one point.')\n",
    "\n",
    "    if adapter_type == \"SAM-Med2D-B\":\n",
    "        predictor = predictor_with_adapter\n",
    "    else:\n",
    "        predictor = predictor_without_adapter\n",
    "\n",
    "    image_pil = Image.fromarray(input_image)\n",
    "    H, W, _ = input_image.shape\n",
    "    predictor.set_image(input_image)\n",
    "\n",
    "    # Split the ((x, y), label) pairs into the two arrays the predictor takes.\n",
    "    point_coords = np.array([coord for coord, _ in selected_points])\n",
    "    point_labels = np.array([label for _, label in selected_points])\n",
    "\n",
    "    masks, _, logits = predictor.predict(\n",
    "        point_coords=point_coords,\n",
    "        point_labels=point_labels,\n",
    "        mask_input=last_mask,\n",
    "        multimask_output=True,\n",
    "    )\n",
    "\n",
    "    # Transparent overlay of the same size as the image; masks are painted onto it.\n",
    "    mask_image = Image.new('RGBA', (W, H), color=(0, 0, 0, 0))\n",
    "    mask_draw = ImageDraw.Draw(mask_image)\n",
    "    for mask in masks:\n",
    "        draw_mask(mask, mask_draw, random_color=False)\n",
    "\n",
    "    draw_point(selected_points, ImageDraw.Draw(image_pil))\n",
    "\n",
    "    image_pil = image_pil.convert('RGBA')\n",
    "    image_pil.alpha_composite(mask_image)\n",
    "    # Kept by the UI and passed back in as `last_mask` on the next run.\n",
    "    last_mask = torch.sigmoid(torch.as_tensor(logits, dtype=torch.float, device=device))\n",
    "    return [(image_pil, mask_image), last_mask]\n",
    "\n",
    "\n",
    "def draw_mask(mask, draw, random_color=False):\n",
    "    \"\"\"Paint every non-zero pixel of ``mask`` onto ``draw`` in a single colour.\n",
    "\n",
    "    Args:\n",
    "        mask: array whose non-zero entries are painted; expected to be 2-D so\n",
    "            that each non-zero index is a (row, col) pair.\n",
    "        draw: object with an ``ImageDraw``-style ``point(xy, fill=...)`` method.\n",
    "        random_color: use a random RGB colour instead of the fixed default;\n",
    "            the alpha value is 153 in both cases.\n",
    "    \"\"\"\n",
    "    if random_color:\n",
    "        color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 153)\n",
    "    else:\n",
    "        color = (30, 144, 255, 153)\n",
    "\n",
    "    # np.nonzero yields (row, col); reverse each pair to (x, y) for drawing.\n",
    "    xy = np.transpose(np.nonzero(mask))[:, ::-1]\n",
    "    if xy.size == 0:\n",
    "        return  # empty mask: nothing to paint\n",
    "    # One call with a flat [x0, y0, x1, y1, ...] list instead of one Python call per pixel.\n",
    "    draw.point(xy.ravel().tolist(), fill=color)\n",
    "\n",
    "def draw_point(point, draw, r=5):\n",
    "    \"\"\"Draw one filled circle of radius ``r`` for each labelled point.\n",
    "\n",
    "    Args:\n",
    "        point: iterable of ``((x, y), label)`` pairs.\n",
    "        draw: object with an ``ImageDraw``-style ``ellipse(box, fill=...)`` method.\n",
    "        r: circle radius.\n",
    "\n",
    "    Label 1 is drawn green and label 0 red; any other label is skipped.\n",
    "    \"\"\"\n",
    "    # Distinct loop names so the `point` parameter is not shadowed while iterating.\n",
    "    for (x, y), label in point:\n",
    "        if label == 1:\n",
    "            fill = 'green'\n",
    "        elif label == 0:\n",
    "            fill = 'red'\n",
    "        else:\n",
    "            continue\n",
    "        draw.ellipse((x - r, y - r, x + r, y + r), fill=fill)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Keyboard interruption in main thread... closing server.\n"
     ]
    },
    {
     "data": {
      "text/plain": []
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Marker colour for each point label, indexed by label:\n",
    "# 0 (background_point) -> (255, 0, 0), 1 (foreground_point) -> (0, 255, 0).\n",
    "colors = [(255, 0, 0), (0, 255, 0)]\n",
    "# cv2 marker type for each label, passed as `markerType` to cv2.drawMarker\n",
    "# (in OpenCV's MarkerTypes enum 1 is MARKER_TILTED_CROSS and 5 is MARKER_TRIANGLE_UP).\n",
    "markers = [1, 5]\n",
    "# Build the UI; components created inside the `with` belong to this Blocks app.\n",
    "block = gr.Blocks()\n",
    "with block:\n",
    "    # Header row: description text plus the model selector.\n",
    "    with gr.Row():\n",
    "        gr.Markdown(\n",
    "            '''# SAM-Med2D!🚀\n",
    "            SAM-Med2D is an interactive segmentation model based on the SAM model for medical scenarios, supporting multi-point interactive segmentation and box interaction. \n",
    "            Currently, only multi-point interaction is supported in this application. More information can be found on [**GitHub**](https://github.com/uni-medical/SAM-Med2D/tree/main).\n",
    "            '''\n",
    "        )\n",
    "        with gr.Row():\n",
    "            # Model selector; its value is passed to run_sammed as `adapter_type`.\n",
    "            adapter_type = gr.Dropdown([\"SAM-Med2D-B\", \"SAM-Med2D-B_w/o_adapter\"], value='SAM-Med2D-B', label=\"Select Adapter\")\n",
    "            # adapter_type.change(fn = update_model, inputs=[adapter_type])\n",
    "          \n",
    "    with gr.Tab(label='Image'):\n",
    "        # NOTE(review): .style(...) looks like Gradio 3.x API - confirm the pinned Gradio version before upgrading.\n",
    "        with gr.Row().style(equal_height=True):\n",
    "            with gr.Column():\n",
    "                # Input image shown to the user; markers are drawn onto it on click.\n",
    "                original_image = gr.State(value=None)   # marker-free copy of the image, None until one is loaded\n",
    "                input_image = gr.Image(type=\"numpy\")\n",
    "                # Point prompt controls and per-session state.\n",
    "                with gr.Column():\n",
    "                    selected_points = gr.State([])      # list of (point, label) pairs collected from clicks\n",
    "                    # Second output of run_sammed, passed back in on the next run; None until the first run.\n",
    "                    last_mask = gr.State(None) \n",
    "                    with gr.Row():\n",
    "                        gr.Markdown('You can click on the image to select points prompt. Default: foreground_point.')\n",
    "                        undo_button = gr.Button('Undo point')\n",
    "                    # No initial value: get_point treats an unset choice as foreground_point.\n",
    "                    radio = gr.Radio(['foreground_point', 'background_point'], label='point labels')\n",
    "                button = gr.Button(\"Run!\")\n",
    "        \n",
    "            # Gallery that receives the first output of run_sammed.\n",
    "            gallery_sammed = gr.Gallery(\n",
    "                    label=\"Generated images\", show_label=False, elem_id=\"gallery\").style(preview=True, grid=2,object_fit=\"scale-down\")\n",
    "            \n",
    "    def process_example(img):\n",
    "        \"\"\"Example handler: keep the image, start with no points and no previous mask.\"\"\"\n",
    "        no_points, no_mask = [], None\n",
    "        return img, no_points, no_mask\n",
    "\n",
    "    def store_img(img):\n",
    "        \"\"\"Upload handler: keep the image, start with no points and no previous mask.\"\"\"\n",
    "        no_points, no_mask = [], None\n",
    "        return img, no_points, no_mask\n",
    "\n",
    "    # A newly uploaded image replaces the stored original and clears points and mask.\n",
    "    input_image.upload(\n",
    "        fn=store_img,\n",
    "        inputs=[input_image],\n",
    "        outputs=[original_image, selected_points, last_mask],\n",
    "    )\n",
    "    # user click the image to get points, and show the points on the image\n",
    "    def get_point(img, sel_pix, point_type, evt: gr.SelectData):\n",
    "        \"\"\"Record the clicked position in ``sel_pix`` and return the image with markers.\n",
    "\n",
    "        'background_point' stores label 0; 'foreground_point' and any other\n",
    "        value of ``point_type`` store label 1. ``sel_pix`` is extended in place\n",
    "        and every stored point is drawn onto ``img``.\n",
    "        \"\"\"\n",
    "        new_label = 0 if point_type == 'background_point' else 1\n",
    "        sel_pix.append((evt.index, new_label))\n",
    "\n",
    "        # Redraw all stored points, colour and marker shape chosen by label.\n",
    "        for xy, lbl in sel_pix:\n",
    "            cv2.drawMarker(img, xy, colors[lbl], markerType=markers[lbl], markerSize=20, thickness=5)\n",
    "\n",
    "        if isinstance(img, np.ndarray):\n",
    "            return img\n",
    "        return np.array(img)\n",
    "\n",
    "    input_image.select(\n",
    "        fn=get_point,\n",
    "        inputs=[input_image, selected_points, radio],\n",
    "        outputs=[input_image],\n",
    "    )\n",
    "\n",
    "    # undo the selected point\n",
    "    def undo_points(orig_img, sel_pix):\n",
    "        \"\"\"Drop the last point and redraw the remaining markers on a clean copy.\n",
    "\n",
    "        Args:\n",
    "            orig_img: the stored marker-free image, or ``None`` if none is loaded.\n",
    "            sel_pix: list of ``(point, label)`` pairs; the last one is removed in place.\n",
    "\n",
    "        Returns:\n",
    "            ``(image, None)``: the redrawn image for the image component and\n",
    "            ``None`` to reset the stored previous mask.\n",
    "        \"\"\"\n",
    "        if orig_img is None:   # nothing loaded yet, so there is nothing to redraw\n",
    "            return None, None\n",
    "        # Work on a copy so the stored original stays free of markers.\n",
    "        temp = orig_img.copy() if isinstance(orig_img, np.ndarray) else np.array(orig_img)\n",
    "        if sel_pix:\n",
    "            sel_pix.pop()\n",
    "            for point, label in sel_pix:\n",
    "                cv2.drawMarker(temp, point, colors[label], markerType=markers[label], markerSize=20, thickness=5)\n",
    "        # Exactly two values for the two outputs wired below. No channel swap is applied,\n",
    "        # so markers keep the same colours they have when drawn by get_point.\n",
    "        return temp, None\n",
    "    \n",
    "    undo_button.click(\n",
    "        undo_points,\n",
    "        [original_image, selected_points],\n",
    "        [input_image, last_mask]\n",
    "    )\n",
    "\n",
    "    # Bundled demo images; clicking one runs process_example to fill the three states.\n",
    "    with gr.Row():\n",
    "        with gr.Column():\n",
    "            gr.Examples(\n",
    "                [\"data_demo/images/amos_0507_31.png\", \"data_demo/images/s0114_111.png\"],\n",
    "                inputs=[input_image],\n",
    "                outputs=[original_image, selected_points, last_mask],\n",
    "                fn=process_example,\n",
    "                run_on_click=True,\n",
    "            )\n",
    "\n",
    "    # Segment the stored marker-free image using the collected points.\n",
    "    button.click(\n",
    "        fn=run_sammed,\n",
    "        inputs=[original_image, selected_points, last_mask, adapter_type],\n",
    "        outputs=[gallery_sammed, last_mask],\n",
    "    )\n",
    "\n",
    "# NOTE(review): share=True requests a public Gradio share link - confirm that is intended for this data.\n",
    "block.launch(debug=True, share=True, show_error=True)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MMseg",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
